package com.iteaj.iot.client;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.mqtt.MqttConnAckMessage;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;

/**
 * 多阶段链接
 */
public interface MultiStageConnect {

    AttributeKey<Boolean> CONN_INIT_FLAG = AttributeKey.valueOf("IOT:MSC:CONN_INIT_FLAG");

    Channel getChannel();

    ChannelPromise getConnectFinishedFlag();

    MultiStageConnect setConnectFinishedFlag(ChannelPromise promise);

    default ChannelPromise setSuccess() {
        this.getChannel().attr(CONN_INIT_FLAG).set(Boolean.TRUE);

        // 连接关闭时需要重置初始化标签
        this.getChannel().closeFuture().addListener(future -> {
            this.getChannel().attr(CONN_INIT_FLAG).set(null);
        });

        final ChannelPromise channelPromise = getConnectFinishedFlag();
        if(channelPromise != null) {
            this.setConnectFinishedFlag(null);
            channelPromise.setSuccess();
        }

        return channelPromise;
    }

    default ChannelPromise setFailure(Throwable cause) {
        final ChannelPromise channelPromise = getConnectFinishedFlag();
        if(channelPromise != null) {
            this.setConnectFinishedFlag(null);
            channelPromise.setFailure(cause);
        }

        return channelPromise;
    }

    /**
     * 多阶段连接
     * @param timeout
     * @return
     */
    default ChannelFuture stageConnect(long timeout) {
        // 说明此客户端已经发起了连接但是还没有完成整个初始化
        if(this.getConnectFinishedFlag() != null) {
            return this.getConnectFinishedFlag();
        }

        // 连接未初始化或者没有激活
        if(this.getChannel() == null || !this.getChannel().isActive()) {
            long ackTimeout = timeout == 0 ? 3000 : timeout;
            synchronized (this) {
                // 说明此客户端已经发起了连接但是还没有完成整个初始化
                if(this.getConnectFinishedFlag() != null) {
                    return this.getConnectFinishedFlag();
                }

                // 连接已经建立并且已经初始化完成
                if(this.getChannel() != null && Boolean.TRUE.equals(this.getChannel().attr(CONN_INIT_FLAG).get())) {
                    return this.getChannel().newSucceededFuture();
                }

                // 连接不存在或者已经断开 重新建立一个连接
                ChannelFuture connect = this.doStageConnect();
                final ChannelPromise returnPromise = connect.channel().newPromise();

                try {
                    if (!connect.await(ackTimeout, TimeUnit.MILLISECONDS)) {// socket连接失败
                        return returnPromise.setFailure(new TimeoutException("连接超时[" + ackTimeout + "]"));
                    }
                } catch (InterruptedException e) {
                    return returnPromise.setFailure(e);
                }

                Attribute<Boolean> attr = connect.channel().attr(CONN_INIT_FLAG);
                if(connect.isSuccess()) { // socket连接成功
                    // 如果连接成功先校验是否初始化完成
                    if(!Boolean.TRUE.equals(attr.get())) { // 还未初始化完成
                        this.setConnectFinishedFlag(returnPromise); // 用来标记此连接未初始化
                        // 在指定时间段内没有初始化完成表示初始化失败
                        connect.channel().eventLoop().schedule(() -> {
                            if(attr.get() == null) { // 初始化失败
                                this.setFailure(new TimeoutException("连接初始化超时("+ackTimeout+"ms)"));
                            } else if(Boolean.TRUE.equals(attr.get())) { // 初始化成功
                                if(getConnectFinishedFlag() != null) {
                                    this.setSuccess();
                                }
                            }
                        }, ackTimeout, TimeUnit.MILLISECONDS);
                    } else { // 已经初始化完成
                        return returnPromise.setSuccess();
                    }
                } else { // 连接失败
                    return returnPromise.setFailure(connect.cause());
                }

                return returnPromise;
            }

        } else {
            Attribute<Boolean> attr = this.getChannel().attr(CONN_INIT_FLAG);
            if(Boolean.TRUE.equals(attr.get())) { // 说明已经完成初始化
                return this.getChannel().newSucceededFuture();
            } else { // 未初始化完成
                final ChannelPromise promise = getConnectFinishedFlag();
                if(promise != null) {
                    return promise;
                } else if(Boolean.TRUE.equals(attr.get())) { // 已经完成且成功
                    return this.getChannel().newSucceededFuture();
                } else { // 已经完成但是失败了
                    return this.getChannel().newFailedFuture(new ClientProtocolException("初始化失败"));
                }
            }
        }
    }

    ChannelFuture doStageConnect();
}
